from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import mxnet as mx
import numpy as np
from config import config

# Bit width passed as `act_bit` to the quantized activation/convolution ops.
ACT_BIT = 1
# Momentum passed to the BatchNorm layers built by the module-level helpers.
bn_mom = 0.9
# `workspace` value forwarded to convolution ops by helpers that do not take
# an explicit workspace argument.
workspace = 256
# When True, shortcut symbols get the attribute mirror_stage='True'.
memonger = False


def Conv(**kwargs):
    """Create a convolution symbol, forwarding every keyword argument as-is."""
    return mx.sym.Convolution(**kwargs)


def Act(data, act_type, name):
    """Create an activation symbol named `name` on top of `data`.

    'prelu' is built with the LeakyReLU operator; every other `act_type`
    is handed to the generic Activation operator.
    """
    if act_type != 'prelu':
        return mx.symbol.Activation(data=data, act_type=act_type, name=name)
    return mx.sym.LeakyReLU(data=data, act_type='prelu', name=name)


#def lin(data, num_filter, workspace, name, binarize, dcn):
#  bit = 1
#  if not binarize:
#    if not dcn:
#        conv1 = Conv(data=data, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0),
#                                      no_bias=True, workspace=workspace, name=name + '_conv')
#        bn1 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn')
#        act1 = Act(data=bn1, act_type='relu', name=name + '_relu')
#        return act1
#    else:
#        bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn')
#        act1 = Act(data=bn1, act_type='relu', name=name + '_relu')
#        conv1_offset = mx.symbol.Convolution(name=name+'_conv_offset', data = act1,
#                num_filter=18, pad=(1, 1), kernel=(3, 3), stride=(1, 1))
#        conv1 = mx.contrib.symbol.DeformableConvolution(name=name+"_conv", data=act1, offset=conv1_offset,
#                num_filter=num_filter, pad=(1,1), kernel=(3, 3), num_deformable_group=1, stride=(1, 1), dilate=(1, 1), no_bias=False)
#        #conv1 = Conv(data=act1, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1,1),
#        #                              no_bias=False, workspace=workspace, name=name + '_conv')
#        return conv1
#  else:
#    bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn')
#    act1 = Act(data=bn1, act_type='relu', name=name + '_relu')
#    conv1 = mx.sym.QConvolution_v1(data=act1, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0),
#                               no_bias=True, workspace=workspace, name=name + '_conv', act_bit=ACT_BIT, weight_bit=bit)
#    conv1 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn2')
#    return conv1


def lin3(data, num_filter, workspace, name, k, g=1, d=1):
    """Build a stride-1 k x k convolution followed by BatchNorm and ReLU.

    Args:
        data: input symbol.
        num_filter: number of output filters of the convolution.
        workspace: workspace value forwarded to the convolution.
        name: prefix; layers are named name + '_conv' / '_bn' / '_relu'.
        k: square kernel size.
        g: number of convolution groups.
        d: dilation. Only honoured when k == 3 (pad and dilate are both d);
            for any other k the padding is (k - 1) // 2 and no dilation
            argument is passed.

    Returns:
        The ReLU output symbol.
    """
    conv_args = dict(data=data,
                     num_filter=num_filter,
                     kernel=(k, k),
                     stride=(1, 1),
                     num_group=g,
                     no_bias=True,
                     workspace=workspace,
                     name=name + '_conv')
    if k == 3:
        conv_args['pad'] = (d, d)
        conv_args['dilate'] = (d, d)
    else:
        half = (k - 1) // 2
        conv_args['pad'] = (half, half)
    normed = mx.sym.BatchNorm(data=Conv(**conv_args),
                              fix_gamma=False,
                              momentum=bn_mom,
                              eps=2e-5,
                              name=name + '_bn')
    return Act(data=normed, act_type='relu', name=name + '_relu')


def ConvFactory(data,
                num_filter,
                kernel,
                stride=(1, 1),
                pad=(0, 0),
                act_type="relu",
                mirror_attr=None,
                with_act=True,
                dcn=False,
                name=''):
    """Build a convolution -> BatchNorm (-> activation) stack.

    Args:
        data: input symbol.
        num_filter: number of output filters.
        kernel: kernel shape for the plain convolution.
        stride: stride of the (plain or deformable) convolution.
        pad: padding for the plain convolution.
        act_type: activation type handed to `Act` when `with_act` is true.
        mirror_attr: accepted for signature compatibility; not used here.
            Defaults to None rather than a shared mutable dict.
        with_act: when false the BatchNorm output is returned directly.
        dcn: when true a deformable convolution is built instead. In that
            branch `kernel` and `pad` are ignored (fixed 3x3 kernel, pad 1)
            and the convolution carries a bias.
        name: prefix for the layer names.

    Returns:
        The activation symbol, or the BatchNorm symbol if `with_act` is
        false.
    """
    if not dcn:
        conv = mx.symbol.Convolution(data=data,
                                     num_filter=num_filter,
                                     kernel=kernel,
                                     stride=stride,
                                     pad=pad,
                                     no_bias=True,
                                     workspace=workspace,
                                     name=name + '_conv')
    else:
        # The offset branch always runs at stride 1 with a 3x3 kernel.
        conv_offset = mx.symbol.Convolution(name=name + '_conv_offset',
                                            data=data,
                                            num_filter=18,
                                            pad=(1, 1),
                                            kernel=(3, 3),
                                            stride=(1, 1))
        conv = mx.contrib.symbol.DeformableConvolution(name=name + "_conv",
                                                       data=data,
                                                       offset=conv_offset,
                                                       num_filter=num_filter,
                                                       pad=(1, 1),
                                                       kernel=(3, 3),
                                                       num_deformable_group=1,
                                                       stride=stride,
                                                       dilate=(1, 1),
                                                       no_bias=False)
    bn = mx.symbol.BatchNorm(data=conv,
                             fix_gamma=False,
                             momentum=bn_mom,
                             eps=2e-5,
                             name=name + '_bn')
    if not with_act:
        return bn
    return Act(bn, act_type, name=name + '_relu')


class CAB:
    """Recursive, memoised builder of a block indexed by grid cells (w, h).

    Cell (n, n) is the input; the remaining cells of row h == n are chained
    `lin3` units that halve the channel count at each step; every other
    cell averages the results of cells (w + 1, h + 1) and (w, h + 1).
    """

    def __init__(self, data, nFilters, nModules, n, workspace, name, dilate,
                 group):
        self.data, self.nFilters, self.nModules = data, nFilters, nModules
        self.n, self.workspace, self.name = n, workspace, name
        self.dilate, self.group = dilate, group
        # Cache of (w, h) -> (symbol, channel count).
        self.sym_map = {}

    def _row_n(self, w, h):
        """Build cell (w, h) for the row where h == self.n."""
        if w == self.n:
            return (self.data, self.nFilters)
        x_sym, x_ch = self.get_output(w + 1, h)
        half = int(x_ch * 0.5)
        # Only the cell right next to the input uses the configured dilation.
        dilation = self.dilate if w == self.n - 1 else 1
        unit = lin3(x_sym, half, self.workspace,
                    "%s_w%d_h%d_1" % (self.name, w, h), 3, self.group,
                    dilation)
        return (unit, half)

    def _merge(self, w, h):
        """Build cell (w, h) from cells (w + 1, h + 1) and (w, h + 1)."""
        x_sym, x_ch = self.get_output(w + 1, h + 1)
        y_sym, y_ch = self.get_output(w, h + 1)
        if h % 2 == 1 and h != w:
            # Group count equals the channel count here.
            xbody = lin3(x_sym, x_ch, self.workspace,
                         "%s_w%d_h%d_2" % (self.name, w, h), 3, x_ch)
        else:
            xbody = x_sym
        if w == 0:
            ybody = lin3(y_sym, y_ch, self.workspace,
                         "%s_w%d_h%d_3" % (self.name, w, h), 3, self.group)
        else:
            ybody = y_sym
        # Concatenate along axis 1 before summing with the x branch.
        ybody = mx.sym.concat(y_sym, ybody, dim=1)
        total = mx.sym.add_n(xbody,
                             ybody,
                             name="%s_w%d_h%d_add" % (self.name, w, h))
        return (total / 2, x_ch)

    def get_output(self, w, h):
        """Return the cached or newly built (symbol, channels) of (w, h)."""
        key = (w, h)
        try:
            return self.sym_map[key]
        except KeyError:
            pass
        result = self._row_n(w, h) if h == self.n else self._merge(w, h)
        self.sym_map[key] = result
        return result

    def get(self):
        """Return the output symbol of cell (1, 1)."""
        return self.get_output(1, 1)[0]


def conv_resnet(data, num_filter, stride, dim_match, name, binarize, dcn,
                dilate, **kwargs):
    """Build a pre-activation bottleneck residual unit.

    Three BatchNorm -> activation -> convolution stages (1x1, 3x3, 1x1) are
    stacked; the first two produce int(num_filter * 0.5) filters and the
    last one num_filter. With `binarize` the quantized activation and
    convolution operators are used instead of the regular ones.

    Args:
        data: input symbol.
        num_filter: number of output filters of the unit.
        stride: stride of the projection shortcut (the main path is always
            stride 1).
        dim_match: when true the input itself is the shortcut; otherwise a
            1x1 convolution on the first activation is used.
        name: prefix for all layer names.
        binarize: select quantized operators.
        dcn, dilate, **kwargs: accepted but not used by this unit.

    Returns:
        The sum of the main path and the shortcut.
    """
    bit = 1
    half = int(num_filter * 0.5)

    def norm(src, suffix):
        return mx.sym.BatchNorm(data=src,
                                fix_gamma=False,
                                eps=2e-5,
                                momentum=bn_mom,
                                name=name + suffix)

    def activate(src, suffix):
        if binarize:
            return mx.sym.QActivation(data=src,
                                      act_bit=ACT_BIT,
                                      name=name + suffix,
                                      backward_only=True)
        return Act(data=src, act_type='relu', name=name + suffix)

    def convolve(src, filters, kernel, pad, suffix):
        if binarize:
            return mx.sym.QConvolution(data=src,
                                       num_filter=filters,
                                       kernel=kernel,
                                       stride=(1, 1),
                                       pad=pad,
                                       no_bias=True,
                                       workspace=workspace,
                                       name=name + suffix,
                                       act_bit=ACT_BIT,
                                       weight_bit=bit)
        return Conv(data=src,
                    num_filter=filters,
                    kernel=kernel,
                    stride=(1, 1),
                    pad=pad,
                    no_bias=True,
                    workspace=workspace,
                    name=name + suffix)

    # Main path: 1x1 -> 3x3 -> 1x1.
    act1 = activate(norm(data, '_bn1'), '_relu1')
    conv1 = convolve(act1, half, (1, 1), (0, 0), '_conv1')
    act2 = activate(norm(conv1, '_bn2'), '_relu2')
    conv2 = convolve(act2, half, (3, 3), (1, 1), '_conv2')
    act3 = activate(norm(conv2, '_bn3'), '_relu3')
    conv3 = convolve(act3, num_filter, (1, 1), (0, 0), '_conv3')

    if dim_match:
        shortcut = data
    elif binarize:
        shortcut = mx.sym.QConvolution(data=act1,
                                       num_filter=num_filter,
                                       kernel=(1, 1),
                                       stride=stride,
                                       pad=(0, 0),
                                       no_bias=True,
                                       workspace=workspace,
                                       name=name + '_sc',
                                       act_bit=ACT_BIT,
                                       weight_bit=bit)
    else:
        shortcut = Conv(data=act1,
                        num_filter=num_filter,
                        kernel=(1, 1),
                        stride=stride,
                        no_bias=True,
                        workspace=workspace,
                        name=name + '_sc')
    if memonger:
        shortcut._set_attr(mirror_stage='True')
    return conv3 + shortcut


def conv_hpm(data, num_filter, stride, dim_match, name, binarize, dcn,
             dilation, **kwargs):
    """Build a three-stage unit whose stage outputs are concatenated.

    Each stage is BatchNorm -> activation -> 3x3 convolution, producing
    int(num_filter * 0.5), int(num_filter * 0.25) and
    int(num_filter * 0.25) filters. The three convolution outputs are
    concatenated and added to the shortcut.

    Args:
        data: input symbol.
        num_filter: filter count used for the shortcut projection and to
            derive the per-stage filter counts.
        stride: stride of the projection shortcut.
        dim_match: when true the input itself is the shortcut.
        name: prefix for all layer names.
        binarize: use quantized activation/convolution (pad fixed to 1) and
            add BatchNorm after the concat and after the shortcut conv.
        dcn: in the non-binarized case, use deformable convolutions
            (pad 1, dilate 1) instead of dilated plain convolutions.
        dilation: pad and dilation of the plain convolutions.
        **kwargs: accepted but not used.

    Returns:
        The sum of the concatenated stages and the shortcut.
    """
    bit = 1

    def norm(src, suffix):
        return mx.sym.BatchNorm(data=src,
                                fix_gamma=False,
                                eps=2e-5,
                                momentum=bn_mom,
                                name=name + suffix)

    def stage(src, idx, filters):
        """Return (activation, convolution) of stage `idx` applied to src."""
        relu_name = '%s_relu%d' % (name, idx)
        conv_name = '%s_conv%d' % (name, idx)
        if binarize:
            act = mx.sym.QActivation(data=src,
                                     act_bit=ACT_BIT,
                                     name=relu_name,
                                     backward_only=True)
            conv = mx.sym.QConvolution_v1(data=act,
                                          num_filter=filters,
                                          kernel=(3, 3),
                                          stride=(1, 1),
                                          pad=(1, 1),
                                          no_bias=True,
                                          workspace=workspace,
                                          name=conv_name,
                                          act_bit=ACT_BIT,
                                          weight_bit=bit)
            return act, conv
        act = Act(data=src, act_type='relu', name=relu_name)
        if dcn:
            offset = mx.symbol.Convolution(name=conv_name + '_offset',
                                           data=act,
                                           num_filter=18,
                                           pad=(1, 1),
                                           kernel=(3, 3),
                                           stride=(1, 1))
            conv = mx.contrib.symbol.DeformableConvolution(
                name=conv_name,
                data=act,
                offset=offset,
                num_filter=filters,
                pad=(1, 1),
                kernel=(3, 3),
                num_deformable_group=1,
                stride=(1, 1),
                dilate=(1, 1),
                no_bias=True)
        else:
            conv = Conv(data=act,
                        num_filter=filters,
                        kernel=(3, 3),
                        stride=(1, 1),
                        pad=(dilation, dilation),
                        dilate=(dilation, dilation),
                        no_bias=True,
                        workspace=workspace,
                        name=conv_name)
        return act, conv

    act1, conv1 = stage(norm(data, '_bn1'), 1, int(num_filter * 0.5))
    _, conv2 = stage(norm(conv1, '_bn2'), 2, int(num_filter * 0.25))
    _, conv3 = stage(norm(conv2, '_bn3'), 3, int(num_filter * 0.25))

    conv4 = mx.symbol.Concat(*[conv1, conv2, conv3])
    if binarize:
        conv4 = norm(conv4, '_bn4')

    if dim_match:
        shortcut = data
    elif binarize:
        shortcut = mx.sym.QConvolution_v1(data=act1,
                                          num_filter=num_filter,
                                          kernel=(1, 1),
                                          stride=stride,
                                          pad=(0, 0),
                                          no_bias=True,
                                          workspace=workspace,
                                          name=name + '_sc',
                                          act_bit=ACT_BIT,
                                          weight_bit=bit)
        shortcut = norm(shortcut, '_sc_bn')
    else:
        shortcut = Conv(data=act1,
                        num_filter=num_filter,
                        kernel=(1, 1),
                        stride=stride,
                        no_bias=True,
                        workspace=workspace,
                        name=name + '_sc')
    if memonger:
        shortcut._set_attr(mirror_stage='True')
    return conv4 + shortcut


def block17(net,
            input_num_channels,
            scale=1.0,
            with_act=True,
            act_type='relu',
            mirror_attr=None,
            name=''):
    """Build a two-tower residual block with factorised 1x7 / 7x1 convs.

    Args:
        net: input symbol.
        input_num_channels: filter count of the final 1x1 projection; it
            has to match the channel count of `net` for the residual add.
        scale: factor applied to the tower output before the residual add.
        with_act: apply an activation after the residual add.
        act_type: activation type used when `with_act` is true.
        mirror_attr: attribute dict passed to the final activation. None
            (the default) means an empty dict, created per call so that no
            mutable default is shared between calls.
        name: prefix for the layer names.

    Returns:
        The activated (or raw, if `with_act` is false) residual sum.
    """
    attr = {} if mirror_attr is None else mirror_attr
    tower_conv = ConvFactory(net, 192, (1, 1), name=name + '_conv')
    # NOTE(review): 129 filters looks unusual next to 160/192 - confirm it
    # is intentional before changing; kept as-is.
    tower_conv1_0 = ConvFactory(net, 129, (1, 1), name=name + '_conv1_0')
    # NOTE(review): pads (1, 2) and (2, 1) are not the usual "same" padding
    # for 1x7 / 7x1 kernels, although the pair restores the input size;
    # kept as-is - verify.
    tower_conv1_1 = ConvFactory(tower_conv1_0,
                                160, (1, 7),
                                pad=(1, 2),
                                name=name + '_conv1_1')
    tower_conv1_2 = ConvFactory(tower_conv1_1,
                                192, (7, 1),
                                pad=(2, 1),
                                name=name + '_conv1_2')
    tower_mixed = mx.symbol.Concat(*[tower_conv, tower_conv1_2])
    tower_out = ConvFactory(tower_mixed,
                            input_num_channels, (1, 1),
                            with_act=False,
                            name=name + '_conv_out')
    net = net + scale * tower_out
    if not with_act:
        return net
    return mx.symbol.Activation(data=net, act_type=act_type, attr=attr)


def block35(net,
            input_num_channels,
            scale=1.0,
            with_act=True,
            act_type='relu',
            mirror_attr=None,
            name=''):
    """Build a three-tower residual block of 1x1 and 3x3 convolutions.

    Tower widths are fractions (0.25, 0.375, 0.5) of `input_num_channels`.
    The towers are concatenated, projected back to `input_num_channels`
    with a 1x1 convolution (no activation), scaled and added to `net`.

    Args:
        net: input symbol.
        input_num_channels: filter count of the final projection and base
            for the tower widths.
        scale: factor applied to the tower output before the residual add.
        with_act: apply an activation after the residual add.
        act_type: activation type used when `with_act` is true.
        mirror_attr: attribute dict passed to the final activation. None
            (the default) means an empty dict, created per call so that no
            mutable default is shared between calls.
        name: prefix for the layer names.

    Returns:
        The activated (or raw, if `with_act` is false) residual sum.
    """
    attr = {} if mirror_attr is None else mirror_attr
    M = 1.0
    quarter = int(input_num_channels * 0.25 * M)

    # Tower 0: single 1x1.
    tower_conv = ConvFactory(net, quarter, (1, 1), name=name + '_conv')
    # Tower 1: 1x1 -> 3x3.
    tower_conv1_0 = ConvFactory(net, quarter, (1, 1), name=name + '_conv1_0')
    tower_conv1_1 = ConvFactory(tower_conv1_0,
                                quarter, (3, 3),
                                pad=(1, 1),
                                name=name + '_conv1_1')
    # Tower 2: 1x1 -> 3x3 -> 3x3 with growing width.
    tower_conv2_0 = ConvFactory(net, quarter, (1, 1), name=name + '_conv2_0')
    tower_conv2_1 = ConvFactory(tower_conv2_0,
                                int(input_num_channels * 0.375 * M), (3, 3),
                                pad=(1, 1),
                                name=name + '_conv2_1')
    tower_conv2_2 = ConvFactory(tower_conv2_1,
                                int(input_num_channels * 0.5 * M), (3, 3),
                                pad=(1, 1),
                                name=name + '_conv2_2')
    tower_mixed = mx.symbol.Concat(*[tower_conv, tower_conv1_1, tower_conv2_2])
    tower_out = ConvFactory(tower_mixed,
                            input_num_channels, (1, 1),
                            with_act=False,
                            name=name + '_conv_out')

    net = net + scale * tower_out
    if not with_act:
        return net
    return mx.symbol.Activation(data=net, act_type=act_type, attr=attr)


def conv_inception(data, num_filter, stride, dim_match, name, binarize, dcn,
                   dilate, **kwargs):
    """Build a `block35` unit, or delegate to `conv_resnet`.

    Binarization is not supported (asserted). When the unit has to stride
    or change dimensions (`stride[0] > 1` or `dim_match` false) the work is
    delegated to `conv_resnet` with the same arguments.
    """
    assert not binarize
    needs_resnet = stride[0] > 1 or not dim_match
    if not needs_resnet:
        return block35(data, num_filter, name=name + '_block35')
    return conv_resnet(data, num_filter, stride, dim_match, name, binarize,
                       dcn, dilate, **kwargs)


def conv_cab(data, num_filter, stride, dim_match, name, binarize, dcn, dilate,
             **kwargs):
    """Build a `CAB` unit, or delegate to `conv_hpm`.

    When the unit has to stride or change dimensions (`stride[0] > 1` or
    `dim_match` false) the work is delegated to `conv_hpm` with the same
    arguments; otherwise a CAB with n=4, one module and group 1 is built.
    """
    needs_hpm = stride[0] > 1 or not dim_match
    if not needs_hpm:
        return CAB(data, num_filter, 1, 4, workspace, name, dilate, 1).get()
    return conv_hpm(data, num_filter, stride, dim_match, name, binarize,
                    dcn, dilate, **kwargs)


def conv_block(data, num_filter, stride, dim_match, name, binarize, dcn,
               dilate):
    """Build one unit of the type selected by `config.net_block`.

    Supported values are 'resnet', 'inception', 'hpm' and 'cab'; all
    arguments are forwarded unchanged to the matching builder.

    Raises:
        ValueError: if `config.net_block` is not one of the supported
            values (previously this fell through and returned None, which
            only failed later, far from the cause).
    """
    builders = {
        'resnet': conv_resnet,
        'inception': conv_inception,
        'hpm': conv_hpm,
        'cab': conv_cab,
    }
    try:
        builder = builders[config.net_block]
    except KeyError:
        raise ValueError('unsupported config.net_block: %r (expected one of %s)'
                         % (config.net_block, ', '.join(sorted(builders))))
    return builder(data, num_filter, stride, dim_match, name, binarize, dcn,
                   dilate)


def hourglass(data, nFilters, nModules, n, workspace, name, binarize, dcn):
    """Recursively build an hourglass of depth `n`.

    The upper branch runs `nModules` units at the input resolution. The
    lower branch max-pools by 2, runs `nModules` units, recurses (or, at
    n == 1, runs `nModules` more units), runs `nModules` units again and is
    upsampled by 2 (nearest) before being added to the upper branch.

    Args:
        data: input symbol.
        nFilters: filter count of every unit.
        nModules: number of units per chain.
        n: remaining recursion depth.
        workspace: only forwarded to the recursive call.
        name: prefix for layer names.
        binarize: forwarded to `conv_block`.
        dcn: only forwarded to the recursive call; the units built here
            always receive False.

    Returns:
        The sum of the upper branch and the upsampled lower branch.
    """
    factor = 2
    unit_dcn = False

    def chain(sym, tag):
        # nModules dimension-preserving units named <name>_<tag>_<i>.
        for idx in range(nModules):
            sym = conv_block(sym, nFilters, (1, 1), True,
                             "%s_%s_%d" % (name, tag, idx), binarize,
                             unit_dcn, 1)
        return sym

    up1 = chain(data, 'up1')
    pooled = mx.sym.Pooling(data=data,
                            kernel=(factor, factor),
                            stride=(factor, factor),
                            pad=(0, 0),
                            pool_type='max')
    low1 = chain(pooled, 'low1')
    if n > 1:
        low2 = hourglass(low1, nFilters, nModules, n - 1, workspace,
                         "%s_%d" % (name, n - 1), binarize, dcn)
    else:
        low2 = chain(low1, 'low2')
    low3 = chain(low2, 'low3')
    up2 = mx.symbol.UpSampling(low3,
                               scale=factor,
                               sample_type='nearest',
                               workspace=512,
                               name='%s_upsampling_%s' % (name, n),
                               num_args=1)
    return mx.symbol.add_n(up1, up2)


class STA:
    """Recursive, memoised builder of a multi-scale block over cells (w, h).

    Cell (n, n) is the input (recorded with a size value of 64). The
    remaining cells of row h == n apply CAB -> 2x max-pool -> CAB and halve
    the recorded value. Every other cell averages a branch from cell
    (w + 1, h + 1) with a branch from cell (w, h + 1), upsampling the
    latter when its recorded value is half that of the former.
    """

    def __init__(self, data, nFilters, nModules, n, workspace, name):
        self.data, self.nFilters, self.nModules = data, nFilters, nModules
        self.n, self.workspace, self.name = n, workspace, name
        # Cache of (w, h) -> (symbol, recorded size value).
        self.sym_map = {}

    def get_conv(self, data, name, dilate=1, group=1):
        """Return the output of a CAB (n=4) built on `data`."""
        return CAB(data, self.nFilters, self.nModules, 4, self.workspace,
                   name, dilate, group).get()

    def _deconv_bn_relu(self, src, w, h, scale, momentum):
        """Learned `scale`x upsampling: Deconvolution -> BatchNorm -> ReLU."""
        cell = (self.name, w, h)
        up = mx.symbol.Deconvolution(data=src,
                                     num_filter=self.nFilters,
                                     kernel=(scale, scale),
                                     stride=(scale, scale),
                                     name='%s_upsampling_w%d_h%d' % cell,
                                     attr={'lr_mult': '1.0'},
                                     workspace=self.workspace)
        up = mx.sym.BatchNorm(data=up,
                              fix_gamma=False,
                              momentum=momentum,
                              eps=2e-5,
                              name="%s_w%d_h%d_y_bn" % cell)
        return Act(data=up, act_type='relu', name="%s_w%d_h%d_y_act" % cell)

    def get_output(self, w, h):
        """Return the cached or newly built (symbol, size value) of (w, h)."""
        assert 1 <= w <= config.net_n + 1
        assert 1 <= h <= config.net_n + 1
        s = 2
        momentum = 0.9
        key = (w, h)
        try:
            return self.sym_map[key]
        except KeyError:
            pass
        cell = (self.name, w, h)
        ret = None
        if h == self.n:
            if w == self.n:
                ret = self.data, 64
            else:
                src, src_size = self.get_output(w + 1, h)
                down = self.get_conv(src, "%s_w%d_h%d_1" % cell)
                down = mx.sym.Pooling(data=down,
                                      kernel=(s, s),
                                      stride=(s, s),
                                      pad=(0, 0),
                                      pool_type='max')
                down = self.get_conv(down, "%s_w%d_h%d_2" % cell)
                ret = down, src_size // 2
        else:
            x_sym, x_size = self.get_output(w + 1, h + 1)
            y_sym, y_size = self.get_output(w, h + 1)

            if h % 2 == 1 and h != w:
                # Group count equals the filter count here.
                xbody = lin3(x_sym, self.nFilters, self.workspace,
                             "%s_w%d_h%d_x" % cell, 3, self.nFilters, 1)
            else:
                xbody = x_sym

            if x_size // y_size != 2:
                ybody = self.get_conv(y_sym, "%s_w%d_h%d_5" % cell)
            elif w > 1:
                ybody = self._deconv_bn_relu(y_sym, w, h, s, momentum)
            elif h >= 1:
                ybody = mx.symbol.UpSampling(
                    y_sym,
                    scale=s,
                    sample_type='nearest',
                    workspace=512,
                    name='%s_upsampling_w%d_h%d' % cell,
                    num_args=1)
                ybody = self.get_conv(ybody, "%s_w%d_h%d_4" % cell)
            else:
                ybody = self._deconv_bn_relu(y_sym, w, h, s, momentum)
                ybody = Conv(data=ybody,
                             num_filter=self.nFilters,
                             kernel=(3, 3),
                             stride=(1, 1),
                             pad=(1, 1),
                             no_bias=True,
                             name="%s_w%d_h%d_y_conv2" % cell,
                             workspace=self.workspace)
                ybody = mx.sym.BatchNorm(data=ybody,
                                         fix_gamma=False,
                                         momentum=momentum,
                                         eps=2e-5,
                                         name="%s_w%d_h%d_y_bn2" % cell)
                ybody = Act(data=ybody,
                            act_type='relu',
                            name="%s_w%d_h%d_y_act2" % cell)

            if config.net_sta == 2 and h == 3 and w == 2:
                # Gate the average with an average-pooled version of the
                # neighbouring cell (w + 1, h); the pooling window is that
                # cell's recorded size value.
                z_sym, z_size = self.get_output(w + 1, h)
                gate = mx.sym.Pooling(data=z_sym,
                                      kernel=(z_size, z_size),
                                      stride=(z_size, z_size),
                                      pad=(0, 0),
                                      pool_type='avg')
                body = xbody + ybody
                body = body / 2
                body = mx.sym.broadcast_mul(body, gate)
            else:
                body = xbody + ybody
                body = body / 2
            ret = body, x_size

        assert ret is not None
        self.sym_map[key] = ret
        return ret

    def get(self):
        """Return the output symbol of cell (1, 1)."""
        return self.get_output(1, 1)[0]


class SymCoherent:
    """Split a batch symbol into two halves and mirror the first one.

    The first half is flipped along axis 3 and its axis-1 channels are
    re-ordered according to ``flip_order``; the second half is returned
    untouched, so the two results can be compared element-wise.
    """

    def __init__(self, per_batch_size):
        """Remember the batch size and set up the channel permutation.

        Args:
            per_batch_size: size of axis 0 of the symbol later passed to
                ``get``; only ``per_batch_size // 2`` entries go in each half.
        """
        self.per_batch_size = per_batch_size
        # Permutation of the 68 channel indices 0..67 applied after the flip.
        # NOTE(review): presumably maps each landmark channel to its
        # left/right counterpart -- confirm against the label layout.
        self.flip_order = [
            16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
            26, 25, 24, 23, 22, 21, 20, 19, 18, 17,
            27, 28, 29, 30,
            35, 34, 33, 32, 31,
            45, 44, 43, 42, 47, 46,
            39, 38, 37, 36, 41, 40,
            54, 53, 52, 51, 50, 49, 48,
            59, 58, 57, 56, 55,
            64, 63, 62, 61, 60,
            67, 66, 65,
        ]

    def get(self, data):
        """Return ``(mirrored_first_half, second_half)`` of ``data``.

        Args:
            data: symbol whose axis 0 is expected to have ``per_batch_size``
                entries and whose axis 1 has at least 68 channels.

        Returns:
            A pair of symbols: the first half flipped along axis 3 with its
            channels permuted by ``flip_order``, and the unmodified second half.
        """
        half = self.per_batch_size // 2
        first_half = mx.sym.slice_axis(data, axis=0, begin=0, end=half)
        second_half = mx.sym.slice_axis(data, axis=0, begin=half, end=half * 2)
        mirrored = mx.sym.flip(first_half, axis=3)
        # Pick the channels one by one in permuted order, then stitch them
        # back together along the channel axis.
        reordered = [
            mx.sym.slice_axis(mirrored, axis=1, begin=src, end=src + 1)
            for src in self.flip_order
        ]
        return mx.sym.concat(*reordered, dim=1), second_half


def l2_loss(x, y):
    """Return the mean smooth-L1 distance between ``x`` and ``y``.

    Despite the function name, the difference is passed through
    ``smooth_l1`` (with ``scalar=1.0``) rather than being squared, and the
    result is averaged over all elements.
    """
    residual = x - y
    return mx.symbol.mean(mx.symbol.smooth_l1(residual, scalar=1.0))


def ce_loss(x, y):
    """Cross-entropy between a spatial softmax of ``x`` and the target ``y``.

    A softmax is taken over axes 2 and 3 of ``x`` (independently for every
    index along axes 0 and 1); its logarithm is multiplied element-wise by
    ``y``, negated, and averaged over axes 1, 2 and 3.

    The log-softmax is evaluated as ``shifted - log(sum(exp(shifted)))``
    instead of ``log(exp(shifted) / sum)``. Both are mathematically equal,
    but the latter yields ``log(0) = -inf`` once ``exp`` underflows for a
    logit far below the maximum, and ``-inf * 0`` then turns the loss into
    NaN. In the form used here the sum always contains ``exp(0) = 1`` (from
    the maximum element), so every intermediate value stays finite.

    Args:
        x: symbol with the raw scores; must have axes 2 and 3.
        y: symbol with the target weights, multiplied element-wise with the
            log-probabilities.

    Returns:
        Symbol holding one loss value per entry along axis 0.
    """
    # Subtract the per-map maximum so that exp() cannot overflow.
    x_max = mx.sym.max(x, axis=[2, 3], keepdims=True)
    shifted = mx.sym.broadcast_minus(x, x_max)
    # log of the softmax normaliser; the summed value is >= 1, so log >= 0.
    log_norm = mx.sym.log(
        mx.sym.sum(mx.sym.exp(shifted), axis=[2, 3], keepdims=True))
    log_prob = mx.sym.broadcast_minus(shifted, log_norm)
    loss = log_prob * y * -1.0
    return mx.symbol.mean(loss, axis=[1, 2, 3])


def get_symbol(num_classes):
    """Assemble the stacked heatmap network together with its training losses.

    Layout, as built below: a stem (``conv0`` / ``bn0`` / ``relu0``), three
    ``conv_block`` units with a 2x2 max-pool after the first one, followed by
    ``config.net_stacks`` stacks. Each stack runs an ``STA`` module (when
    ``config.net_sta > 0``) or an ``hourglass``, then a ``conv_block``, a
    ``ConvFactory`` layer and a heatmap head with ``num_classes`` output
    channels. Every head contributes a ``ce_loss`` term against the
    ``softmax_label`` variable; when ``config.net_coherent > 0`` an extra
    ``l2_loss`` term between the two halves returned by ``SymCoherent`` is
    collected as well.

    Args:
        num_classes: number of output channels of every heatmap head.

    Returns:
        An ``mx.symbol.Group`` holding, in this order: one ``MakeLoss`` per
        stack, the summed coherent loss (only if any was collected) and the
        gradient-blocked heatmap of the last stack.
    """
    width = config.multiplier
    stem_filters = max(int(64 * width), 32)
    mid_filters = max(int(128 * width), 32)
    num_filters = int(256 * width)

    num_modules = 1
    num_stacks = config.net_stacks
    binarize = config.net_binarize
    input_size = config.input_img_size
    label_size = config.output_label_size
    coherent_mode = config.net_coherent
    sta_mode = config.net_sta
    depth = config.net_n
    dcn_mode = config.net_dcn
    per_batch_size = config.per_batch_size
    for tag, value in (('binarize', binarize),
                       ('use_coherent', coherent_mode),
                       ('use_STA', sta_mode),
                       ('use_N', depth),
                       ('use_DCN', dcn_mode),
                       ('per_batch_size', per_batch_size)):
        print(tag, value)

    coherentor = SymCoherent(per_batch_size)
    downscale = input_size // label_size
    print(input_size, label_size, downscale)

    # Input normalisation: (x - 127.5) / 128.
    data = mx.sym.Variable(name='data')
    data = (data - 127.5) * 0.0078125
    gt_label = mx.symbol.Variable(name='softmax_label')

    # Stem: a strided 7x7 convolution when the input is 4x larger than the
    # label map, otherwise a stride-1 3x3 convolution.
    if downscale == 4:
        stem_geometry = {'kernel': (7, 7), 'stride': (2, 2), 'pad': (3, 3)}
    else:
        stem_geometry = {'kernel': (3, 3), 'stride': (1, 1), 'pad': (1, 1)}
    body = Conv(data=data,
                num_filter=stem_filters,
                no_bias=True,
                name="conv0",
                workspace=workspace,
                **stem_geometry)
    body = mx.sym.BatchNorm(data=body,
                            fix_gamma=False,
                            eps=2e-5,
                            momentum=bn_mom,
                            name='bn0')
    body = Act(data=body, act_type='relu', name='relu0')

    dcn = False
    body = conv_block(body, mid_filters, (1, 1), stem_filters == mid_filters,
                      'res0', False, dcn, 1)
    body = mx.sym.Pooling(data=body,
                          kernel=(2, 2),
                          stride=(2, 2),
                          pad=(0, 0),
                          pool_type='max')
    body = conv_block(body, mid_filters, (1, 1), True, 'res1', False, dcn,
                      1)  # TODO (carried over; intent not recorded)
    body = conv_block(body, num_filters, (1, 1), mid_filters == num_filters,
                      'res2', binarize, dcn, 1)

    heatmap = None
    stack_losses = []
    coherent_losses = []
    last_stack = num_stacks - 1

    for i in range(num_stacks):
        is_last = i == last_stack
        shortcut = body

        # Stack backbone.
        if sta_mode > 0:
            body = STA(body, num_filters, num_modules, depth + 1, workspace,
                       'sta%d' % i).get()
        else:
            body = hourglass(body, num_filters, num_modules, depth, workspace,
                             'stack%d_hg' % i, binarize, dcn)
        for j in range(num_modules):
            body = conv_block(body, num_filters, (1, 1), True,
                              'stack%d_unit%d' % (i, j), binarize, dcn, 1)

        # Heatmap head; net_dcn >= 2 switches both the ``ll`` layer and the
        # head itself to deformable convolutions.
        head_dcn = bool(dcn_mode >= 2)
        ll = ConvFactory(body,
                         num_filters, (1, 1),
                         dcn=head_dcn,
                         name='stack%d_ll' % i)
        head_name = "heatmap" if is_last else "heatmap%d" % i
        if head_dcn:
            head_offset = mx.symbol.Convolution(name=head_name + '_offset',
                                                data=ll,
                                                num_filter=18,
                                                pad=(1, 1),
                                                kernel=(3, 3),
                                                stride=(1, 1))
            out = mx.contrib.symbol.DeformableConvolution(
                name=head_name,
                data=ll,
                offset=head_offset,
                num_filter=num_classes,
                pad=(1, 1),
                kernel=(3, 3),
                num_deformable_group=1,
                stride=(1, 1),
                dilate=(1, 1),
                no_bias=False)
        else:
            out = Conv(data=ll,
                       num_filter=num_classes,
                       kernel=(1, 1),
                       stride=(1, 1),
                       pad=(0, 0),
                       name=head_name,
                       workspace=workspace)
        if is_last:
            heatmap = out

        # Per-stack supervision.
        stack_losses.append(ce_loss(out, gt_label))
        if coherent_mode > 0:
            ux, dx = coherentor.get(out)
            coherent_losses.append(l2_loss(ux, dx) / num_stacks)

        if is_last:
            continue

        # Feed the next stack: shortcut + projected features + projected
        # heatmap, all brought to ``num_filters`` channels.
        ll2 = Conv(data=ll,
                   num_filter=num_filters,
                   kernel=(1, 1),
                   stride=(1, 1),
                   pad=(0, 0),
                   name="stack%d_ll2" % i,
                   workspace=workspace)
        out2 = Conv(data=out,
                    num_filter=num_filters,
                    kernel=(1, 1),
                    stride=(1, 1),
                    pad=(0, 0),
                    name="stack%d_out2" % i,
                    workspace=workspace)
        body = mx.symbol.add_n(shortcut, ll2, out2)

        # net_dcn of 1 or 3 adds a deformable convolution between stacks.
        bridge_dcn = bool(dcn_mode == 1 or dcn_mode == 3)
        if bridge_dcn:
            bridge_name = "stack%d_out3" % i
            bridge_offset = mx.symbol.Convolution(name=bridge_name + '_offset',
                                                  data=body,
                                                  num_filter=18,
                                                  pad=(1, 1),
                                                  kernel=(3, 3),
                                                  stride=(1, 1))
            body = mx.contrib.symbol.DeformableConvolution(
                name=bridge_name,
                data=body,
                offset=bridge_offset,
                num_filter=num_filters,
                pad=(1, 1),
                kernel=(3, 3),
                num_deformable_group=1,
                stride=(1, 1),
                dilate=(1, 1),
                no_bias=False)

    # The prediction output carries no gradient; the losses do.
    pred = mx.symbol.BlockGrad(heatmap)
    outputs = [mx.symbol.MakeLoss(stack_loss) for stack_loss in stack_losses]
    if coherent_losses:
        coherent_weight = 0.0001
        coherent_total = mx.symbol.add_n(*coherent_losses)
        outputs.append(
            mx.symbol.MakeLoss(coherent_total, grad_scale=coherent_weight))
    outputs.append(pred)
    return mx.symbol.Group(outputs)


def init_weights(sym, data_shape_dict):
    """Create explicit initial values for selected arguments of ``sym``.

    Argument shapes are inferred from ``data_shape_dict``; arguments are then
    initialised by name:

    * names ending in ``offset_weight`` / ``offset_bias``: zeros;
    * names starting with ``fc6_``: ``_weight`` gets ``normal(0, 0.01)``,
      ``_bias`` gets zeros, anything else is skipped;
    * names containing ``upsampling``: bilinear kernel.

    All other arguments are left out of the result.

    Args:
        sym: symbol providing ``list_arguments`` and ``infer_shape``.
        data_shape_dict: keyword arguments forwarded to ``sym.infer_shape``.

    Returns:
        ``(arg_params, aux_params)``; ``aux_params`` is always an empty dict.
    """
    arg_names = sym.list_arguments()
    arg_shapes, _, _ = sym.infer_shape(**data_shape_dict)
    shape_by_name = dict(zip(arg_names, arg_shapes))

    arg_params = {}
    aux_params = {}
    for name, shape in shape_by_name.items():
        if name.endswith(('offset_weight', 'offset_bias')):
            print('initializing', name)
            arg_params[name] = mx.nd.zeros(shape=shape)
        elif name.startswith('fc6_'):
            if name.endswith('_weight'):
                print('initializing', name)
                arg_params[name] = mx.random.normal(0, 0.01, shape=shape)
            elif name.endswith('_bias'):
                print('initializing', name)
                arg_params[name] = mx.nd.zeros(shape=shape)
        elif 'upsampling' in name:
            print('initializing upsampling_weight', name)
            kernel = mx.nd.zeros(shape=shape)
            arg_params[name] = kernel
            # Fills ``kernel`` in place with bilinear-interpolation weights.
            mx.init.Initializer()._init_bilinear(name, kernel)
    return arg_params, aux_params
